#!/bin/env python
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

#----------
# Parameters
#----------
# Y-axis labels for the trend panel, one per experiment row in plotting order.
# The empty string labels a blank spacer row between the 1977-2015 group and
# the 1977-2050 group.
labs = [
    "Observations (1977\u20142015)",
    "HiFLOR (AllForc, 1977\u20142015)",
    "SPEAR (AllForc, 1977\u20142015)",
    "HiFLOR (NatForc, 1977\u20142015)",
    "SPEAR (NatForc, 1977\u20142015)",
    "",
    "HiFLOR (AllForc, 1977\u20142050)",
    "SPEAR (AllForc, 1977\u20142050)",
    "SPEAR (NatForc, 1977\u20142050)",
]

#---------
# Subroutines
#---------
def linreg(X,Y):
    """
    Ordinary least-squares fit of the straight line y = a*x + b.

    The slope a and intercept b minimize the sum of squared vertical
    distances between the points (X[i], Y[i]) and the line.

    Parameters
    ----------
    X, Y : 1-D sequences of numbers with the same length (lists, NumPy
        arrays, or pandas Index/Series all work).

    Returns
    -------
    X : the input X object itself, returned unchanged (handy for plotting).
    YY2 : numpy.ndarray of fitted values a*X + b, one per point.
    aa : float, the fitted slope a (units of Y per unit of X).

    Raises
    ------
    ValueError : if X and Y differ in shape, have fewer than two points, or
        all X values are equal (the slope is undefined).
    """
    x = np.asarray(X, dtype=float)
    y = np.asarray(Y, dtype=float)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError("linreg: X and Y must be 1-D sequences of equal length")
    if x.size < 2:
        raise ValueError("linreg: at least two points are required")

    # Work with deviations from the means. This is algebraically identical to
    # the raw-sums formula (Sxy*N - Sx*Sy) / (Sxx*N - Sx**2), but it avoids the
    # catastrophic cancellation that formula suffers when X holds large values
    # such as calendar years.
    x_mean = x.mean()
    y_mean = y.mean()
    dx = x - x_mean
    dy = y - y_mean
    sxx = np.dot(dx, dx)
    if sxx == 0:
        raise ValueError("linreg: all X values are identical; slope is undefined")

    aa = np.dot(dx, dy) / sxx
    bb = y_mean - aa * x_mean
    # Use the ndarray form of X so plain lists work (float * list would fail).
    YY2 = aa * x + bb
    return X, YY2, aa

def mk_test(x, alpha = 0.05):
    """
    Two-sided Mann-Kendall test for a monotonic trend in `x`.

    Input:
        x:     1-D sequence of data, in time order
        alpha: significance level

    Output:
        h: True if a trend is detected at level `alpha`, False otherwise
        p: two-sided p-value of the test (nan if `x` contains NaN)

    The statistic S = sum_{k<j} sign(x[j] - x[k]) is standardized with the
    tie-corrected variance
        Var(S) = [n(n-1)(2n+5) - sum_t t(t-1)(2t+5)] / 18,
    where t runs over the sizes of groups of equal values. The usual
    continuity correction of -1 (S > 0) or +1 (S < 0) is then applied.

    Examples
    --------
      >>> x = np.random.rand(100)
      >>> h, p = mk_test(x, 0.05)
    """
    # Imported after the docstring so the string above is the real __doc__.
    from scipy.stats import norm

    x = np.asarray(x, dtype=float)
    n = len(x)

    # S: for each k, (#later values above x[k]) - (#later values below x[k]).
    s = 0
    for k in range(n - 1):
        s += np.sign(x[k + 1:] - x[k]).sum()

    # Tie-corrected variance of S. The tie term is SUBTRACTED. Distinct values
    # have count 1 and contribute 0, so the sum can run over every unique value.
    _, counts = np.unique(x, return_counts=True)
    var_s = (n*(n-1)*(2*n+5) - np.sum(counts*(counts-1)*(2*counts+5))) / 18

    # Continuity-corrected standard normal score. A NaN in x makes s NaN, which
    # falls through to the last branch, so z and p become NaN instead of z
    # being left unbound.
    if s > 0:
        z = (s - 1) / np.sqrt(var_s)
    elif s == 0:
        z = 0
    else:
        z = (s + 1) / np.sqrt(var_s)

    # calculate the p_value
    p = 2*(1-norm.cdf(abs(z)))  # two-tailed test
    h = abs(z) > norm.ppf(1-alpha/2)

    return h, p

def cal_trend_pval(syear,eyear,years,data, factor=1.0):
    """
    Linear trend and Mann-Kendall p-value of `data` over years syear..eyear.

    Only entries whose year lies in the inclusive window [syear, eyear] are
    used. The slope comes from `linreg` (units of `data` per year) and is
    multiplied by `factor`. The p-value comes from `mk_test` at alpha = 0.05.

    Returns
    -------
    (slope * factor, p_value)
    """
    in_window = (years >= syear) & (years <= eyear)
    window_years = years[in_window]
    window_data = data[in_window]

    _, p_value = mk_test(window_data, alpha=0.05)
    _, _, slope = linreg(window_years, window_data)
    return slope * factor, p_value

def cal_trends(syear,eyear,years,data,factor=1.0):
    """
    Per-column linear trends and Mann-Kendall p-values of a 2-D array.

    Each column data[:, e] is passed to `cal_trend_pval` with the same year
    window [syear, eyear]. The resulting slopes are then scaled by `factor`.

    Returns
    -------
    trends, pvals : 1-D float arrays with one entry per column of `data`.
    """
    # Unpacking enforces a 2-D shape, as before.
    _, n_columns = np.shape(data)
    per_column = [cal_trend_pval(syear, eyear, years, data[:, col])
                  for col in range(n_columns)]
    trends = np.array([slope * factor for slope, _ in per_column], dtype=float)
    pvals = np.array([p_value for _, p_value in per_column], dtype=float)
    return trends, pvals

#----------
# Read Data
#----------
# The observations are read as a year-indexed pandas Series (typ='series').
# Each model file is read as a year-indexed DataFrame whose columns are
# reduced across axis=1 below to give the "ensemble" mean/max/min per year.
datafl1 = "./Data/Observations/timeseries_SF.json"
datafl2 = "./Data/HiFLOR_AllForc/timeseries_test_SF.json"
datafl3 = "./Data/HiFLOR_NatForc/timeseries_test_SF.json"
datafl4 = "./Data/SPEAR_AllForc/timeseries_test_SF.json"
datafl5 = "./Data/SPEAR_NatForc/timeseries_test_SF.json"

alldata1j = pd.read_json(datafl1, typ='series')
alldata2j = pd.read_json(datafl2)
alldata3j = pd.read_json(datafl3)
alldata4j = pd.read_json(datafl4)
alldata5j = pd.read_json(datafl5)


def _ensemble_summary(frame):
    """Split a year-indexed frame into (years, values, mean, max, min).

    `values` is `frame.to_numpy()`. The three statistics are taken across
    columns (axis=1), giving one number per year.
    """
    values = frame.to_numpy()
    return (frame.index, values,
            values.mean(axis=1), values.max(axis=1), values.min(axis=1))


# Observed, 1977-2015
alldata1a = alldata1j.loc[1977:2015]
years1a = alldata1a.index
alldata1a = alldata1a.to_numpy()

# HiFLOR AllForc, 1977-2015 and 1977-2050
years2a, alldata2a, ensmean2a, ensmax2a, ensmin2a = _ensemble_summary(alldata2j.loc[1977:2015])
years2b, alldata2b, ensmean2b, ensmax2b, ensmin2b = _ensemble_summary(alldata2j.loc[1977:2050])

# HiFLOR NatForc, 1977-2015 and 1977-2020
# NOTE(review): this window ends in 2020, unlike the 2050 windows of the other
# model runs; confirm this is intended (e.g. limited data availability).
years3a, alldata3a, ensmean3a, ensmax3a, ensmin3a = _ensemble_summary(alldata3j.loc[1977:2015])
years3b, alldata3b, ensmean3b, ensmax3b, ensmin3b = _ensemble_summary(alldata3j.loc[1977:2020])

# SPEAR AllForc, 1977-2015 and 1977-2050
years4a, alldata4a, ensmean4a, ensmax4a, ensmin4a = _ensemble_summary(alldata4j.loc[1977:2015])
years4b, alldata4b, ensmean4b, ensmax4b, ensmin4b = _ensemble_summary(alldata4j.loc[1977:2050])

# SPEAR NatForc, 1977-2015 and 1977-2050
years5a, alldata5a, ensmean5a, ensmax5a, ensmin5a = _ensemble_summary(alldata5j.loc[1977:2015])
years5b, alldata5b, ensmean5b, ensmax5b, ensmin5b = _ensemble_summary(alldata5j.loc[1977:2050])

# Linear trends (slope per year from `linreg`) and Mann-Kendall p-values:
#   trend[key], pval[key] : computed from the ensemble-mean series
#   trend_data[key]       : one slope per ensemble column
trend = {}
pval = {}
trend_data = {}

trend['Obs'], pval['Obs'] = cal_trend_pval(1977, 2015, years1a, alldata1a)
trend_data['Obs'] = [trend['Obs']]

# Placeholder entry for the blank spacer row of the trend panel.
trend['dummy'] = np.nan
trend_data['dummy'] = [np.nan]

# (key, last year, years, ensemble mean, ensemble columns). Every window starts
# in 1977, and the entries are processed in this order.
_model_runs = [
    ('HiFLOR_all_p', 2015, years2a, ensmean2a, alldata2a),
    ('HiFLOR_all_f', 2050, years2b, ensmean2b, alldata2b),
    ('HiFLOR_nat_p', 2015, years3a, ensmean3a, alldata3a),
    ('HiFLOR_nat_f', 2020, years3b, ensmean3b, alldata3b),
    ('SPEAR_all_p', 2015, years4a, ensmean4a, alldata4a),
    ('SPEAR_all_f', 2050, years4b, ensmean4b, alldata4b),
    ('SPEAR_nat_p', 2015, years5a, ensmean5a, alldata5a),
    ('SPEAR_nat_f', 2050, years5b, ensmean5b, alldata5b),
]
for run_key, last_year, run_years, run_mean, run_members in _model_runs:
    trend[run_key], pval[run_key] = cal_trend_pval(1977, last_year, run_years, run_mean)
    # Per-member p-values are not used afterwards; they land in `dummy`.
    trend_data[run_key], dummy = cal_trends(1977, last_year, run_years, run_members)

# Row order of the trend panel; it must line up with `labs`.
_experiment_order = ["Obs", "HiFLOR_all_p", "SPEAR_all_p", "HiFLOR_nat_p", "SPEAR_nat_p",
                     'dummy', "HiFLOR_all_f", "SPEAR_all_f", "SPEAR_nat_f"]

num1 = sum(len(trend_data[name]) for name in _experiment_order)  # member-level rows
num2 = len(_experiment_order)                                     # one row per experiment

columns = ["Experiment", "Trends"]
index1 = np.arange(num1)
index2 = np.arange(num2)


def _trend_color(name):
    """Colour for experiment `name`: white for the spacer row; red or blue when
    pval[name] <= 0.05 and trend[name] is positive or negative; gray otherwise."""
    if name == "dummy":
        return "white"
    significant = pval[name] <= 0.05
    if significant and trend[name] > 0:
        return "red"
    if significant and trend[name] < 0:
        return "blue"
    return "gray"


colors = [_trend_color(name) for name in _experiment_order]
ccolors = dict(zip(_experiment_order, colors))

# One row per entry of trend_data. Obs and the spacer get a NaN placeholder so
# their category still occupies a row without contributing a value.
df = pd.DataFrame(
    [[name, np.nan if name in ("Obs", "dummy") else value]
     for name in _experiment_order
     for value in trend_data[name]],
    columns=columns, index=index1, dtype=object)

# One row per experiment holding trend[name].
dfmean = pd.DataFrame(
    [[name, trend[name]] for name in _experiment_order],
    columns=columns, index=index2, dtype=object)

#----------
# Plot
#----------
def _plot_series_with_trend(ax, years, values, color, name):
    """Draw `values` against `years` on `ax` as a marked line.

    The line is labelled "<name> (P-value=<p>)", where p is the Mann-Kendall
    p-value. When p <= 0.05, the least-squares trend line is also drawn,
    dashed with lw=3.
    """
    _, p_value = mk_test(values, alpha=0.05)
    ax.plot(years, values, "o-", color=color, label="%s (P-value=%.2f)" % (name, p_value))
    if p_value <= 0.05:
        fit_years, fit_values, _ = linreg(years, values)
        ax.plot(fit_years, fit_values, "--", color=color, lw=3)


plt.figure(0, figsize=(15,5))

# Panel (c): observed series and HiFLOR AllForc ensemble, 1977-2015.
ax1 = plt.subplot(1, 2, 1)
_plot_series_with_trend(ax1, years1a, alldata1a, "black", "Observed")
# Shaded ensemble min-max envelope, drawn before the ensemble-mean line.
ax1.fill_between(years2a, ensmin2a, ensmax2a, color="red", alpha=0.2)
_plot_series_with_trend(ax1, years2a, ensmean2a, "red", "HiFLOR")

plt.legend(loc=2, prop={"size":14})
plt.xlabel('Year',fontsize=14)
plt.ylabel('Anomalous days',fontsize=14)
plt.title("(c) QSF Anomalous Days (1977\u20142015)",fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlim(1975, 2017)
#plt.ylim(0,20)

# Panel (d): per-member trend distributions (box + jittered points), with the
# per-experiment trend overlaid as large squares coloured by `ccolors`.
ax2 = plt.subplot(1, 2, 2)

sns.boxplot(y="Experiment",x="Trends", data=df, whis=1.5, color="c", palette=ccolors, boxprops=dict(alpha=0.3))
sns.stripplot(y="Experiment",x="Trends", data=df, jitter=True, size=4, color=".3",linewidth=0)
sns.stripplot(y="Experiment",x="Trends", data=dfmean, marker="s", jitter=False, size=14, palette=ccolors, linewidth=0)

ax2.set_yticklabels(labs)
ax2.axvline(x=0, linestyle='dashed', color="gray")

plt.title("(d) Linear Trend in QSF Anomalous Days",fontsize=16)
plt.xlabel('Days per year',fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

plt.tight_layout()
plt.savefig("./Fig8cd.png")
